{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DT.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ajhk76kTn8L4",
        "colab_type": "text"
      },
      "source": [
        "# 第5章 决策树"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-iCRuHmOn8L5",
        "colab_type": "text"
      },
      "source": [
        "- ID3（基于信息增益）\n",
        "- C4.5（基于信息增益比）\n",
        "- CART 二叉决策树（gini指数）"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EbDeryI9n8L6",
        "colab_type": "text"
      },
      "source": [
        "#### entropy：$H(x) = -\\sum_{i=1}^{n}p_i\\log{p_i}$\n",
        "\n",
        "#### conditional entropy: $H(X|Y)=\\sum{P(X|Y)}\\log{P(X|Y)}$\n",
        "\n",
        "#### information gain : $g(D, A)=H(D)-H(D|A)$\n",
        "\n",
        "#### information gain ratio: $g_R(D, A) = \\frac{g(D,A)}{H_{A}(D)}$\n",
        "\n",
        "#### gini index:$Gini(D)=\\sum_{k=1}^{K}p_k\\log{p_k}=1-\\sum_{k=1}^{K}p_k^2$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qemEGcJ7n8L6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline\n",
        "\n",
        "from sklearn.datasets import load_iris\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n",
        "from collections import Counter\n",
        "import math\n",
        "from math import log\n",
        "\n",
        "import pprint"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XyUeXAC_n8L-",
        "colab_type": "text"
      },
      "source": [
        "### 例 5.1"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YtNCGcaHn8L_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def create_data():\n",
        "    datasets = [['青年', '否', '否', '一般', '否'],\n",
        "               ['青年', '否', '否', '好', '否'],\n",
        "               ['青年', '是', '否', '好', '是'],\n",
        "               ['青年', '是', '是', '一般', '是'],\n",
        "               ['青年', '否', '否', '一般', '否'],\n",
        "               ['中年', '否', '否', '一般', '否'],\n",
        "               ['中年', '否', '否', '好', '否'],\n",
        "               ['中年', '是', '是', '好', '是'],\n",
        "               ['中年', '否', '是', '非常好', '是'],\n",
        "               ['中年', '否', '是', '非常好', '是'],\n",
        "               ['老年', '否', '是', '非常好', '是'],\n",
        "               ['老年', '否', '是', '好', '是'],\n",
        "               ['老年', '是', '否', '好', '是'],\n",
        "               ['老年', '是', '否', '非常好', '是'],\n",
        "               ['老年', '否', '否', '一般', '否'],\n",
        "               ]\n",
        "    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']\n",
        "    # 返回数据集和每个维度的名称\n",
        "    return datasets, labels"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Ji3uUZS-n8MB",
        "colab_type": "code",
        "outputId": "bec9dbfe-5016-44ff-a080-6e3eea61bbd6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 514
        }
      },
      "source": [
        "datasets, labels = create_data()\n",
        "train_data = pd.DataFrame(datasets, columns=labels)\n",
        "train_data"
      ],
      "execution_count": 68,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>年龄</th>\n",
              "      <th>有工作</th>\n",
              "      <th>有自己的房子</th>\n",
              "      <th>信贷情况</th>\n",
              "      <th>类别</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>青年</td>\n",
              "      <td>否</td>\n",
              "      <td>否</td>\n",
              "      <td>一般</td>\n",
              "      <td>否</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>青年</td>\n",
              "      <td>否</td>\n",
              "      <td>否</td>\n",
              "      <td>好</td>\n",
              "      <td>否</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>青年</td>\n",
              "      <td>是</td>\n",
              "      <td>否</td>\n",
              "      <td>好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>青年</td>\n",
              "      <td>是</td>\n",
              "      <td>是</td>\n",
              "      <td>一般</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>青年</td>\n",
              "      <td>否</td>\n",
              "      <td>否</td>\n",
              "      <td>一般</td>\n",
              "      <td>否</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>中年</td>\n",
              "      <td>否</td>\n",
              "      <td>否</td>\n",
              "      <td>一般</td>\n",
              "      <td>否</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>中年</td>\n",
              "      <td>否</td>\n",
              "      <td>否</td>\n",
              "      <td>好</td>\n",
              "      <td>否</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>中年</td>\n",
              "      <td>是</td>\n",
              "      <td>是</td>\n",
              "      <td>好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>中年</td>\n",
              "      <td>否</td>\n",
              "      <td>是</td>\n",
              "      <td>非常好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>中年</td>\n",
              "      <td>否</td>\n",
              "      <td>是</td>\n",
              "      <td>非常好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>10</th>\n",
              "      <td>老年</td>\n",
              "      <td>否</td>\n",
              "      <td>是</td>\n",
              "      <td>非常好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>11</th>\n",
              "      <td>老年</td>\n",
              "      <td>否</td>\n",
              "      <td>是</td>\n",
              "      <td>好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>12</th>\n",
              "      <td>老年</td>\n",
              "      <td>是</td>\n",
              "      <td>否</td>\n",
              "      <td>好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>13</th>\n",
              "      <td>老年</td>\n",
              "      <td>是</td>\n",
              "      <td>否</td>\n",
              "      <td>非常好</td>\n",
              "      <td>是</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>14</th>\n",
              "      <td>老年</td>\n",
              "      <td>否</td>\n",
              "      <td>否</td>\n",
              "      <td>一般</td>\n",
              "      <td>否</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "    年龄 有工作 有自己的房子 信贷情况 类别\n",
              "0   青年   否      否   一般  否\n",
              "1   青年   否      否    好  否\n",
              "2   青年   是      否    好  是\n",
              "3   青年   是      是   一般  是\n",
              "4   青年   否      否   一般  否\n",
              "5   中年   否      否   一般  否\n",
              "6   中年   否      否    好  否\n",
              "7   中年   是      是    好  是\n",
              "8   中年   否      是  非常好  是\n",
              "9   中年   否      是  非常好  是\n",
              "10  老年   否      是  非常好  是\n",
              "11  老年   否      是    好  是\n",
              "12  老年   是      否    好  是\n",
              "13  老年   是      否  非常好  是\n",
              "14  老年   否      否   一般  否"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 68
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bWbCbJSmrdkp",
        "colab_type": "code",
        "outputId": "3d16ea53-103e-4ad9-c3d6-c40b86614406",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 287
        }
      },
      "source": [
        "datasets"
      ],
      "execution_count": 53,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[['青年', '否', '否', '一般', '否'],\n",
              " ['青年', '否', '否', '好', '否'],\n",
              " ['青年', '是', '否', '好', '是'],\n",
              " ['青年', '是', '是', '一般', '是'],\n",
              " ['青年', '否', '否', '一般', '否'],\n",
              " ['中年', '否', '否', '一般', '否'],\n",
              " ['中年', '否', '否', '好', '否'],\n",
              " ['中年', '是', '是', '好', '是'],\n",
              " ['中年', '否', '是', '非常好', '是'],\n",
              " ['中年', '否', '是', '非常好', '是'],\n",
              " ['老年', '否', '是', '非常好', '是'],\n",
              " ['老年', '否', '是', '好', '是'],\n",
              " ['老年', '是', '否', '好', '是'],\n",
              " ['老年', '是', '否', '非常好', '是'],\n",
              " ['老年', '否', '否', '一般', '否']]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 53
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9zcwdE1uiTOO",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "30378d95-6bb8-4d63-b57b-33f70b357b3e"
      },
      "source": [
        "labels"
      ],
      "execution_count": 54,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['年龄', '有工作', '有自己的房子', '信贷情况', '类别']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 54
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UP3X4BaVrgYQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "d = {'青年':1, '中年':2, '老年':3, '一般':1, '好':2, '非常好':3, '是':0, '否':1}\n",
        "data = []\n",
        "for i in range(15):\n",
        "    tmp = []\n",
        "    t = datasets[i]\n",
        "    for tt in t:\n",
        "        tmp.append(d[tt])\n",
        "    data.append(tmp)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5tV-TRIQqftJ",
        "colab_type": "code",
        "outputId": "16be14d9-1d5d-4080-db56-0dfcf329830f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 287
        }
      },
      "source": [
        "data = np.array(data);data"
      ],
      "execution_count": 56,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[1, 1, 1, 1, 1],\n",
              "       [1, 1, 1, 2, 1],\n",
              "       [1, 0, 1, 2, 0],\n",
              "       [1, 0, 0, 1, 0],\n",
              "       [1, 1, 1, 1, 1],\n",
              "       [2, 1, 1, 1, 1],\n",
              "       [2, 1, 1, 2, 1],\n",
              "       [2, 0, 0, 2, 0],\n",
              "       [2, 1, 0, 3, 0],\n",
              "       [2, 1, 0, 3, 0],\n",
              "       [3, 1, 0, 3, 0],\n",
              "       [3, 1, 0, 2, 0],\n",
              "       [3, 0, 1, 2, 0],\n",
              "       [3, 0, 1, 3, 0],\n",
              "       [3, 1, 1, 1, 1]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 56
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sN169YUn2LvE",
        "colab_type": "code",
        "outputId": "b0561124-c930-4706-fed3-095308e4f53f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "data.shape"
      ],
      "execution_count": 57,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(15, 5)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 57
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oN7QSJC72UN-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X, y = data[:,:-1], data[:, -1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1KsMqBec5Cwb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 熵\n",
        "def entropy(y):\n",
        "    N = len(y)\n",
        "    count = []\n",
        "    for value in set(y):\n",
        "        count.append(len(y[y == value]))\n",
        "    count = np.array(count)\n",
        "    entro = -np.sum((count / N) * (np.log2(count / N)))\n",
        "    return entro"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DWb2n4RcDflB",
        "colab_type": "code",
        "outputId": "b01d7bde-33c9-467a-bf65-28d8c4d3e77f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "entropy(y)"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.9709505944546686"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ApbpfKpcxw6y",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 条件熵\n",
        "def cond_entropy(X, y, cond):\n",
        "    N = len(y)\n",
        "    cond_X = X[:, cond]\n",
        "    tmp_entro = []\n",
        "    for val in set(cond_X):\n",
        "        tmp_y = y[np.where(cond_X == val)]\n",
        "        tmp_entro.append(len(tmp_y)/N * entropy(tmp_y))\n",
        "    cond_entro = sum(tmp_entro)\n",
        "    return cond_entro"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NF-g7udFK5qN",
        "colab_type": "code",
        "outputId": "9ddf18ff-6ddc-48fb-d766-4dc2c5c4034e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "cond_entropy(X, y, 0)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.8879430945988998"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QXrKL-4mS3Q5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 信息增益\n",
        "def info_gain(X, y, cond):\n",
        "    return entropy(y) - cond_entropy(X, y, cond)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KjLX_NtqezON",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 信息增益比\n",
        "def info_gain_ratio(X, y, cond):\n",
        "    return (entropy(y) - cond_entropy(X, y, cond))/cond_entropy(X, y, cond)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kKY7AVPeF4kh",
        "colab_type": "code",
        "outputId": "670c66c9-8f2c-46b0-b626-633f7fca4cd8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# A1, A2, A3, A4 =》年龄 工作 房子 信贷\n",
        "# 信息增益\n",
        "\n",
        "gain_a1 = info_gain(X, y, 0);gain_a1"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.08300749985576883"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VVTUqG4tSgwn",
        "colab_type": "code",
        "outputId": "72b043b5-a4c0-42db-b12d-5a31f577ef80",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "gain_a2 = info_gain(X, y, 1);gain_a2"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.32365019815155627"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "242jN12HSj_F",
        "colab_type": "code",
        "outputId": "a620d840-ac93-4adb-da8e-151bbe04e95c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "gain_a3 = info_gain(X, y, 2);gain_a3"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.4199730940219749"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m9prl_iaSmM1",
        "colab_type": "code",
        "outputId": "b440c3ed-8bc7-4f36-caf5-d13524d43957",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "gain_a4 = info_gain(X, y, 3);gain_a4"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.36298956253708536"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eIuVibAjpXSr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def best_split(X,y, method='info_gain'):\n",
        "    \"\"\"根据method指定的方法使用信息增益或信息增益比来计算各个维度的最大信息增益（比），返回特征的axis\"\"\"\n",
        "    _, M = X.shape\n",
        "    info_gains = []\n",
        "    if method == 'info_gain':\n",
        "        split = info_gain\n",
        "    elif method == 'info_gain_ratio':\n",
        "        split = info_gain_ratio\n",
        "    else:\n",
        "        print('No such method')\n",
        "        return\n",
        "    for i in range(M):\n",
        "        tmp_gain = split(X, y, i)\n",
        "        info_gains.append(tmp_gain)\n",
        "    best_feature = np.argmax(info_gains)\n",
        "    \n",
        "    return best_feature"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tr6ckR8wriYm",
        "colab_type": "code",
        "outputId": "d2db3308-ce72-4f5d-c74e-893944909892",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "best_split(X,y)"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "2"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iv2hm3ueTKa6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def majorityCnt(y):\n",
        "    \"\"\"当特征使用完时，返回类别数最多的类别\"\"\"\n",
        "    unique, counts = np.unique(y, return_counts=True)\n",
        "    max_idx = np.argmax(counts)\n",
        "    return unique[max_idx]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FXlY9UPxT80q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "e356a964-3f83-40b4-ed66-41fae5266a89"
      },
      "source": [
        "majorityCnt(y)"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": true,
        "id": "EDx9vrfcn8MQ",
        "colab_type": "text"
      },
      "source": [
        "#### ID3, C4.5算法\n",
        "\n",
        "例5.3"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kpgCEMIKRo8_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class DecisionTreeClassifer:\n",
        "    \"\"\"\n",
        "    决策树生成算法，\n",
        "    method指定ID3或C4.5,两方法唯一不同在于特征选择方法不同\n",
        "    info_gain:       信息增益即ID3\n",
        "    info_gain_ratio: 信息增益比即C4.5\n",
        "    \n",
        "    \n",
        "    \"\"\"\n",
        "    def __init__(self, threshold, method='info_gain'):\n",
        "        self.threshold = threshold\n",
        "        self.method = method\n",
        "        \n",
        "    def fit(self, X, y, labels):\n",
        "        labels = labels.copy()\n",
        "        M, N = X.shape\n",
        "        if len(np.unique(y)) == 1:\n",
        "            return y[0]\n",
        "        \n",
        "        if N == 1:\n",
        "            return majorityCnt(y)\n",
        "        \n",
        "        bestSplit = best_split(X,y, method=self.method)\n",
        "        bestFeaLable = labels[bestSplit]\n",
        "        Tree = {bestFeaLable: {}}\n",
        "        del (labels[bestSplit])\n",
        "        \n",
        "        feaVals = np.unique(X[:, bestSplit])\n",
        "        for val in feaVals:\n",
        "            idx = np.where(X[:, bestSplit] == val)\n",
        "            sub_X = X[idx]\n",
        "            sub_y = y[idx]\n",
        "            sub_labels = labels\n",
        "            Tree[bestFeaLable][val] = self.fit(sub_X, sub_y, sub_labels)\n",
        "            \n",
        "        return Tree"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8k4cgeqBn8MQ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "f3c4ca27-9f09-4d42-bd58-3afa41cc32e0"
      },
      "source": [
        "My_Tree = DecisionTreeClassifer(threshold=0.1)\n",
        "My_Tree.fit(X, y, labels)"
      ],
      "execution_count": 69,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'有自己的房子': {0: 0, 1: {'有工作': {0: 0, 1: 1}}}}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 69
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XaGNaDfAoivJ",
        "colab_type": "text"
      },
      "source": [
        "#### CART树"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yXTTfkLCmsdP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class CART:\n",
        "    \"\"\"CART树\"\"\"\n",
        "    def __init__(self, ):\n",
        "        \"to be continue\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6nxK8duGo37e",
        "colab_type": "text"
      },
      "source": [
        "#### 决策树的剪枝"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "N79jPbwWo6rv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\"to be continue\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gop3ocYDn8MZ",
        "colab_type": "text"
      },
      "source": [
        "---\n",
        "\n",
        "## sklearn.tree.DecisionTreeClassifier\n",
        "\n",
        "### criterion : string, optional (default=”gini”)\n",
        "The function to measure the quality of a split. Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nxE7F4sqn8Ma",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# data\n",
        "def create_data():\n",
        "    iris = load_iris()\n",
        "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
        "    df['label'] = iris.target\n",
        "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
        "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
        "    # print(data)\n",
        "    return data[:,:2], data[:,-1]\n",
        "\n",
        "X, y = create_data()\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LyqL3F8un8Mc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from sklearn.tree import DecisionTreeClassifier\n",
        "\n",
        "from sklearn.tree import export_graphviz\n",
        "import graphviz"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nNDjw1Phn8Me",
        "colab_type": "code",
        "outputId": "d2dd416b-1a48-4564-c53f-c801416c6df0",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 125
        }
      },
      "source": [
        "clf = DecisionTreeClassifier()\n",
        "clf.fit(data[:,:-1], data[:,-1])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
              "                       max_features=None, max_leaf_nodes=None,\n",
              "                       min_impurity_decrease=0.0, min_impurity_split=None,\n",
              "                       min_samples_leaf=1, min_samples_split=2,\n",
              "                       min_weight_fraction_leaf=0.0, presort=False,\n",
              "                       random_state=None, splitter='best')"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RsB_iiLZn8Mh",
        "colab_type": "code",
        "outputId": "af553ea2-ec41-496d-ceb9-a750be3d8088",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "clf.predict(np.array([1, 1, 0, 1]).reshape(1,-1)) # A"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([0])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sd2ScBfQu1Bo",
        "colab_type": "code",
        "outputId": "cbdf0ec2-1fd6-48a3-ae24-f4de9d6a442b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "clf.predict(np.array([2, 0, 1, 2]).reshape(1,-1)) # B"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([0])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0E9mMz34u1a3",
        "colab_type": "code",
        "outputId": "7078b7fc-3322-4ee5-9e5f-6f8d575347f6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# Sample C: a single row with the same four features the tree was trained on\n",
        "clf.predict([[2, 1, 0, 1]])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([0])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rmZHZjbYn8Mm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# export_graphviz produces Graphviz DOT source text, not a PDF. Writing it to\n",
        "# \"mytree.pdf\" and reading it back was misleading and needed a disk round-trip;\n",
        "# with out_file=None the DOT string is returned directly for graphviz.Source.\n",
        "dot_graph = export_graphviz(clf, out_file=None)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AeRk07sYn8Mq",
        "colab_type": "code",
        "outputId": "c526e824-6d1d-4f3e-f231-53b9d6447d90",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 379
        }
      },
      "source": [
        "# Render the DOT description of the fitted tree inline in the notebook\n",
        "graphviz.Source(source=dot_graph)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<graphviz.files.Source at 0x7fdc581b9080>"
            ],
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: Tree Pages: 1 -->\n<svg width=\"272pt\" height=\"269pt\"\n viewBox=\"0.00 0.00 272.00 269.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 265)\">\n<title>Tree</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-4,4 -4,-265 268,-265 268,4 -4,4\"/>\n<!-- 0 -->\n<g id=\"node1\" class=\"node\">\n<title>0</title>\n<polygon fill=\"none\" stroke=\"#000000\" points=\"151,-261 56,-261 56,-193 151,-193 151,-261\"/>\n<text text-anchor=\"middle\" x=\"103.5\" y=\"-245.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">X[2] &lt;= 0.5</text>\n<text text-anchor=\"middle\" x=\"103.5\" y=\"-230.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.48</text>\n<text text-anchor=\"middle\" x=\"103.5\" y=\"-215.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">samples = 15</text>\n<text text-anchor=\"middle\" x=\"103.5\" y=\"-200.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">value = [9, 6]</text>\n</g>\n<!-- 1 -->\n<g id=\"node2\" class=\"node\">\n<title>1</title>\n<polygon fill=\"none\" stroke=\"#000000\" points=\"95,-149.5 0,-149.5 0,-96.5 95,-96.5 95,-149.5\"/>\n<text text-anchor=\"middle\" x=\"47.5\" y=\"-134.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n<text text-anchor=\"middle\" x=\"47.5\" y=\"-119.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">samples = 6</text>\n<text text-anchor=\"middle\" x=\"47.5\" y=\"-104.3\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">value = [6, 0]</text>\n</g>\n<!-- 0&#45;&gt;1 
-->\n<g id=\"edge1\" class=\"edge\">\n<title>0&#45;&gt;1</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M85.1635,-192.9465C79.2324,-181.9316 72.6419,-169.6922 66.6532,-158.5703\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"69.6146,-156.6875 61.7919,-149.5422 63.4513,-160.0063 69.6146,-156.6875\"/>\n<text text-anchor=\"middle\" x=\"54.6082\" y=\"-169.7878\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">True</text>\n</g>\n<!-- 2 -->\n<g id=\"node3\" class=\"node\">\n<title>2</title>\n<polygon fill=\"none\" stroke=\"#000000\" points=\"208,-157 113,-157 113,-89 208,-89 208,-157\"/>\n<text text-anchor=\"middle\" x=\"160.5\" y=\"-141.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">X[1] &lt;= 0.5</text>\n<text text-anchor=\"middle\" x=\"160.5\" y=\"-126.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.444</text>\n<text text-anchor=\"middle\" x=\"160.5\" y=\"-111.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">samples = 9</text>\n<text text-anchor=\"middle\" x=\"160.5\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">value = [3, 6]</text>\n</g>\n<!-- 0&#45;&gt;2 -->\n<g id=\"edge2\" class=\"edge\">\n<title>0&#45;&gt;2</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M122.1639,-192.9465C126.888,-184.3271 132.0231,-174.9579 136.954,-165.9611\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"140.0571,-167.5815 141.7941,-157.13 133.9186,-164.2171 140.0571,-167.5815\"/>\n<text text-anchor=\"middle\" x=\"148.7999\" y=\"-177.4283\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">False</text>\n</g>\n<!-- 3 -->\n<g id=\"node4\" class=\"node\">\n<title>3</title>\n<polygon fill=\"none\" stroke=\"#000000\" points=\"151,-53 56,-53 56,0 151,0 151,-53\"/>\n<text text-anchor=\"middle\" x=\"103.5\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n<text text-anchor=\"middle\" 
x=\"103.5\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">samples = 3</text>\n<text text-anchor=\"middle\" x=\"103.5\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">value = [3, 0]</text>\n</g>\n<!-- 2&#45;&gt;3 -->\n<g id=\"edge3\" class=\"edge\">\n<title>2&#45;&gt;3</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M140.4039,-88.9777C135.2656,-80.2786 129.727,-70.9018 124.5425,-62.1247\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"127.399,-60.0786 119.2996,-53.2485 121.3719,-63.6387 127.399,-60.0786\"/>\n</g>\n<!-- 4 -->\n<g id=\"node5\" class=\"node\">\n<title>4</title>\n<polygon fill=\"none\" stroke=\"#000000\" points=\"264,-53 169,-53 169,0 264,0 264,-53\"/>\n<text text-anchor=\"middle\" x=\"216.5\" y=\"-37.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">gini = 0.0</text>\n<text text-anchor=\"middle\" x=\"216.5\" y=\"-22.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">samples = 6</text>\n<text text-anchor=\"middle\" x=\"216.5\" y=\"-7.8\" font-family=\"Times,serif\" font-size=\"14.00\" fill=\"#000000\">value = [0, 6]</text>\n</g>\n<!-- 2&#45;&gt;4 -->\n<g id=\"edge4\" class=\"edge\">\n<title>2&#45;&gt;4</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M180.2435,-88.9777C185.2917,-80.2786 190.7331,-70.9018 195.8266,-62.1247\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"198.9855,-63.6544 200.9776,-53.2485 192.9311,-60.1409 198.9855,-63.6544\"/>\n</g>\n</g>\n</svg>\n"
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dlk_DsGByMix",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}